-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][tosa] Add support for BF16 in tosa.resize
legalization
#158616
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Extend the resize linalg legalization functionality with BF16 support and in accordance to the TOSA specification. Signed-off-by: Georgios Pinitas <[email protected]>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Georgios Pinitas (GeorgeARM) ChangesExtend the resize linalg legalization functionality with BF16 support and in accordance to the TOSA specification. Full diff: https://github.com/llvm/llvm-project/pull/158616.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0a6f2477560a1..1955eec9964eb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1827,8 +1827,8 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
auto resultTy = cast<ShapedType>(op.getType());
auto resultETy = resultTy.getElementType();
- bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
- auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
+ bool floatingPointMode = isa<FloatType>(resultETy);
+ auto floatTy = resultETy;
auto imageH = inputTy.getShape()[1];
auto imageW = inputTy.getShape()[2];
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
index ff2cbbc0b3938..6998aee45b887 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
@@ -12,6 +12,18 @@ func.func @unary_resize_nearest_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x
// -----
+// CHECK-LABEL: @unary_resize_nearest_bf16
+func.func @unary_resize_nearest_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x1x7xbf16> {
+ %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xbf16>
+ // CHECK: return %arg0
+ return %resize : tensor<3x1x1x7xbf16>
+}
+
+// -----
+
// CHECK-LABEL: @unary_resize_nearest_fp16
func.func @unary_resize_nearest_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
@@ -36,6 +48,18 @@ func.func @unary_resize_bilinear_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1
// -----
+// CHECK-LABEL: @unary_resize_bilinear_bf16
+func.func @unary_resize_bilinear_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x1x7xbf16> {
+ %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x1x7xbf16>
+ // CHECK: return %arg0
+ return %resize : tensor<3x1x1x7xbf16>
+}
+
+// -----
+
// CHECK-LABEL: @unary_resize_bilinear_fp16
func.func @unary_resize_bilinear_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
%scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
@@ -60,6 +84,26 @@ func.func @unary_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1x7
// -----
+// CHECK-LABEL: @broadcast_resize_nearest_bf16
+func.func @broadcast_resize_nearest_bf16(%arg0 : tensor<3x1x1x7xbf16>) -> tensor<3x1x5x7xbf16> {
+ // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
+ // CHECK-NEXT{literal}: [[0], [1, 2, 3]] : tensor<3x1x1x7xbf16> into tensor<3x7xbf16>
+ // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x1x5x7xbf16>
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
+ // CHECK-SAME: indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ // CHECK-SAME: ins(%[[COLLAPSE]] : tensor<3x7xbf16>) outs(%[[EMPTY]] : tensor<3x1x5x7xbf16>)
+ // CHECK: ^bb0(%[[IN:.+]]: bf16, %[[OUT:.+]]: bf16):
+ // CHECK: linalg.yield %[[IN]] : bf16
+ %scale = tosa.const_shape { values = dense<[2, 1, 3, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<3x1x1x7xbf16>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x1x5x7xbf16>
+
+ return %resize : tensor<3x1x5x7xbf16>
+}
+
+// -----
+
// CHECK-LABEL: @broadcast_resize_nearest_f32
func.func @broadcast_resize_nearest_f32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x5x7xf32> {
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Curious if there are more bf16 changes needed? Should we create a tracking issue similar to #157378 for ext-bf16 support in tosa-to-linalg? |
Extend the resize linalg legalization functionality with BF16 support and in accordance to the TOSA specification.